TypeTree support in autodiff #144197
Currently, I have implemented only for
r? @ZuseZ4
Some changes occurred in compiler/rustc_ast/src/expand/autodiff_attrs.rs
cc @ZuseZ4
Some changes occurred in compiler/rustc_codegen_llvm/src/builder/autodiff.rs
cc @ZuseZ4
Some changes occurred in compiler/rustc_codegen_ssa
Some changes occurred in compiler/rustc_monomorphize/src/partitioning/autodiff.rs
cc @ZuseZ4
Some changes occurred in compiler/rustc_codegen_gcc
CI is failing; fixing it now.
Some changes occurred in src/tools/enzyme
cc @ZuseZ4
Force-pushed from d3f78d7 to 94aa8bc
☔ The latest upstream changes (presumably #143684) made this pull request unmergeable. Please resolve the merge conflicts.
Force-pushed from 83aa06f to a28f311
use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
    // First, compile to LLVM IR to check for enzyme_type attributes
    let _ir_output = rustc()
        .input("memcpy.rs")
        .arg("-Zautodiff=Enable")
        .arg("-Zautodiff=NoPostopt")
        .opt_level("3")
        .arg("-Clto=fat")
        .arg("--emit=llvm-ir")
        .arg("-o")
        .arg("main.ll")
        .run();

    // Then compile with TypeTree analysis output for the existing checks
    let output = rustc()
        .input("memcpy.rs")
        .arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
        .arg("-Zautodiff=NoPostopt")
        .opt_level("3")
        .arg("-Clto=fat")
        .arg("-g")
        .run();

    let stdout = output.stdout_utf8();
    let stderr = output.stderr_utf8();
    let ir_content = rfs::read_to_string("main.ll");

    rfs::write("memcpy.stdout", &stdout);
    rfs::write("memcpy.stderr", &stderr);
    rfs::write("main.ir", &ir_content);

    llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();

    llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
@jieyouxu, can you review this run-make test? I am creating two check files, one to test the IR and one to test the type analysis output from Enzyme. Is this the correct approach, or should I combine them into one file with a single check?
Signed-off-by: Karan Janthe <[email protected]>
Force-pushed from 110292b to cffae78
☔ The latest upstream changes (presumably #145300) made this pull request unmergeable. Please resolve the merge conflicts.
@@ -2194,3 +2201,335 @@ pub struct DestructuredConst<'tcx> {
    pub variant: Option<VariantIdx>,
    pub fields: &'tcx [ty::Const<'tcx>],
}

// Some types are used a lot. Make sure they don't unintentionally get bigger.
#[cfg(target_pointer_width = "64")]
That's probably code from someone else, can we drop it?
}

#[cfg(not(llvm_enzyme))]
#[allow(dead_code)]
ping
@@ -1253,6 +1253,9 @@ pub struct Resolver<'ra, 'tcx> {
    // that were encountered during resolution. These names are used to generate item names
    // for APITs, so we don't want to leak details of resolution into these names.
    impl_trait_names: FxHashMap<NodeId, Symbol>,

    /// Mapping of autodiff function IDs
    autodiff_map: FxHashMap<LocalDefId, LocalDefId>,
Do we actually use the autodiff_map anywhere?
    _ => panic!(""),
};

let fields = adt_def.all_fields();
let fields = adt_def
    .all_fields()
    .into_iter()
    .zip(offsets.into_iter())
if inner_ty.is_slice() {
    // We know that the length will be passed as extra arg.
    let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span);
    let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
I think on 32 bit targets we probably have size: 4 here, can you check that (e.g. on the playground) and make it selective if needed?
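A minimal sketch of what "selective" could look like, assuming the pointer width in bits is available from the target description. The helper names here are hypothetical and are not part of the PR or of any rustc API; inside the compiler the value would have to come from the target's data layout rather than from the host.

```rust
/// Hypothetical helper: pointer size in bytes for a target whose
/// pointer width is given in bits (16, 32 or 64 for Rust targets).
pub fn pointer_size_bytes(pointer_width_bits: u32) -> usize {
    (pointer_width_bits / 8) as usize
}

/// Pointer size of the machine this code runs on. This is only valid
/// for the *host*; a cross-compiling compiler must not rely on it.
pub fn host_pointer_size_bytes() -> usize {
    std::mem::size_of::<*const u8>()
}
```

With such a helper, the hard-coded `size: 8` would become `size: pointer_size_bytes(width)`, which evaluates to 4 on a 32-bit target.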
    FncTree { args, ret }
}

fn typetree_from_ty<'a>(
Iirc I added the visited Vec here to detect cycles (recursive types). Can you verify that and add a comment describing what it does?
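To illustrate the cycle detection the reviewer describes, here is a toy sketch, not the PR's code. It assumes types are stored in a table and referenced by index; `Ty`, `count_float_leaves`, and the table layout are hypothetical. The `visited` Vec holds the types on the current recursion path, so meeting one of them again means the type is recursive and the walk must stop.

```rust
/// Hypothetical toy type representation: entries refer to each other by index.
#[derive(Debug)]
pub enum Ty {
    F32,
    Ref(usize),         // reference to another type in the table
    Struct(Vec<usize>), // field types, by table index
}

/// Counts the f32 leaves reachable from `id`, stopping at recursive types.
pub fn count_float_leaves(id: usize, table: &[Ty], visited: &mut Vec<usize>) -> usize {
    // `id` is already on the current path: we found a cycle, so stop here.
    if visited.contains(&id) {
        return 0;
    }
    visited.push(id);
    let mut n: usize = 0;
    match &table[id] {
        Ty::F32 => n += 1,
        Ty::Ref(inner) => n += count_float_leaves(*inner, table, visited),
        Ty::Struct(fields) => {
            for field in fields {
                n += count_float_leaves(*field, table, visited);
            }
        }
    }
    // Leaving this type: remove it so sibling fields may contain it again.
    visited.pop();
    n
}

/// A linked-list-like node: `struct Node { value: f32, next: &Node }`.
pub fn linked_list_table() -> Vec<Ty> {
    vec![Ty::Struct(vec![1, 2]), Ty::F32, Ty::Ref(0)]
}
```

Calling `count_float_leaves(0, &linked_list_table(), &mut Vec::new())` terminates because the `next` field leads back to index 0, which is still in `visited`.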
assert!(span.is_some());
let span = span.unwrap();

tcx.sess.dcx().emit_warn(AutodiffUnsafeInnerConstRef { span, ty });
This should also have a test
match ty {
    x if x == tcx.types.f32 => (Kind::Float, 4),
    x if x == tcx.types.f64 => (Kind::Double, 8),
    _ => panic!("floatTy scalar that is neither f32 nor f64"),
I think we support f16 and f128 now, can you add those (And some tests)?
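One possible shape for the extended mapping, as a standalone sketch. `FloatTy` here is a local stand-in for the compiler's float types, and the variant names `Kind::Half` and `Kind::F128` are assumptions; the PR only shows `Kind::Float` and `Kind::Double`. Matching on an enum instead of comparing against `tcx.types` makes the match exhaustive, so the `panic!` arm is no longer needed.

```rust
/// Local stand-in for the four Rust float types (hypothetical).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FloatTy {
    F16,
    F32,
    F64,
    F128,
}

/// `Float` and `Double` appear in the PR; `Half` and `F128` are assumed names.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Kind {
    Half,
    Float,
    Double,
    F128,
}

/// Maps a float type to its TypeTree kind and its size in bytes.
pub fn float_kind_and_size(ty: FloatTy) -> (Kind, usize) {
    match ty {
        FloatTy::F16 => (Kind::Half, 2),
        FloatTy::F32 => (Kind::Float, 4),
        FloatTy::F64 => (Kind::Double, 8),
        FloatTy::F128 => (Kind::F128, 16),
    }
}
```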
// Not an error, because it only causes issues if they are actually read, which we don't check
// yet. We should add such analysis to reliably either issue an error or accept without warning.
// If there only were some research to do that...
pub fn fnc_typetrees<'tcx>(
This fn is comparatively complex, so we should either have tests that cover every case handled here, or split some of the logic out into follow-up PRs.
I would probably squash the current PR into 1-3 commits to start and only add very minimal handling to fnc_typetrees (and the other functions in this file). Then you can add follow-up commits for the extra handling, e.g. one commit for array handling plus array tests, one for SIMD and the SIMD tests, one for recursive handling, one for the InnerConstRef handling, etc. This way we can merge it incrementally and be sure that every piece of this function actually works and is tested.
Yes, I was thinking the same. Should I do follow-up PRs that add the new handling and tests, or is there a basic TypeTree requirement that this PR must meet before it can be merged?
What should the goal of the current PR be? Basic TypeTree support, a working memcpy with tests and, I think, the NOTT flag, right?
You should also add a flag which disables all TypeTree additions (in case they cause bugs, or just for A/B testing to see where TypeTrees allow us to compile something or have a compile-time impact). We already have autodiff=Enable; just add another option to that enum, so people can pass
TypeTrees for Autodiff
What are TypeTrees?
TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are laid out in memory so that it can compute derivatives efficiently.
Structure
Example: fn compute(x: &f32, data: &[f32]) -> f32
Input 0: x: &f32
Input 1: data: &[f32]
Output: f32
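The example above could be encoded roughly as follows. This is a simplified sketch rather than the PR's definitions: the field names mirror the `Type { offset, kind, size, child }` and `FncTree { args, ret }` literals visible in the review snippets, while the exact types, the offsets chosen for the pointees, the 64-bit pointer size, and the function `compute_fnc_tree` are assumptions. The slice length, which is passed as an extra argument, is left out to match the two inputs listed above.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Kind {
    Float,
    Pointer,
}

/// One entry of a TypeTree: what lives at `offset` (-1 means "everywhere").
#[derive(Debug, Clone, PartialEq)]
pub struct Type {
    pub offset: isize,
    pub size: usize,
    pub kind: Kind,
    pub child: TypeTree,
}

#[derive(Debug, Clone, PartialEq, Default)]
pub struct TypeTree(pub Vec<Type>);

/// One TypeTree per input, plus one for the output.
#[derive(Debug, Clone, PartialEq)]
pub struct FncTree {
    pub args: Vec<TypeTree>,
    pub ret: TypeTree,
}

/// Sketch for `fn compute(x: &f32, data: &[f32]) -> f32`.
pub fn compute_fnc_tree() -> FncTree {
    let float_at = |offset: isize| Type {
        offset,
        size: 4,
        kind: Kind::Float,
        child: TypeTree::default(),
    };
    let pointer_to = |pointee: Type| {
        // size 8 assumes a 64-bit target
        TypeTree(vec![Type { offset: -1, size: 8, kind: Kind::Pointer, child: TypeTree(vec![pointee]) }])
    };
    FncTree {
        // Input 0: pointer to a single f32 at offset 0 (assumed).
        // Input 1: pointer to f32 values repeating at every offset (-1).
        args: vec![pointer_to(float_at(0)), pointer_to(float_at(-1))],
        // Output: a plain f32.
        ret: TypeTree(vec![float_at(-1)]),
    }
}
```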
Why Needed?
What Enzyme Does With This Information:
Without TypeTrees (current state):
With TypeTrees (our goal):
TypeTrees - Offset and -1 Explained
Type Structure
Offset Values
Regular Offset (0, 4, 8, etc.)
Specific byte position within a structure
TypeTree for &Point:
Offset -1 (Special: "Everywhere")
Means "this pattern repeats for ALL elements"
Example 1: Array [f32; 100]
Instead of listing 100 separate Types with offsets 0, 4, 8, 12, ..., 396, a single Type with offset -1 describes every element.
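The saving can be made concrete with a small sketch. `Entry` and both helper functions are hypothetical and only model the offsets, not Enzyme's actual representation.

```rust
/// Hypothetical flat TypeTree entry: a kind at a byte offset.
#[derive(Debug, Clone, PartialEq)]
pub struct Entry {
    pub offset: isize, // byte offset, or -1 for "everywhere"
    pub size: usize,
    pub kind: &'static str,
}

/// Verbose form: one entry per array element.
pub fn array_explicit(len: usize, elem_size: usize) -> Vec<Entry> {
    (0..len)
        .map(|i| Entry { offset: (i * elem_size) as isize, size: elem_size, kind: "Float" })
        .collect()
}

/// Compact form: a single entry at offset -1 covers all elements.
pub fn array_everywhere(elem_size: usize) -> Vec<Entry> {
    vec![Entry { offset: -1, size: elem_size, kind: "Float" }]
}
```

For `[f32; 100]`, `array_explicit(100, 4)` produces 100 entries ending at offset 396, while `array_everywhere(4)` produces one.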
Example 2: Slice &[i32]
Example 3: Mixed Structure